{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c5753a0c",
   "metadata": {},
   "source": [
    "# L5: Chat with any LLM! 💬"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a01a3724",
   "metadata": {},
   "source": [
    "Load your HF API key and relevant Python libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fa6fa00-6bd1-4839-bcaf-8bae9267ee79",
   "metadata": {
    "height": 199
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import io\n",
    "import IPython.display\n",
    "from PIL import Image\n",
    "import base64 \n",
    "import requests \n",
    "requests.adapters.DEFAULT_TIMEOUT = 60\n",
    "\n",
    "from dotenv import load_dotenv, find_dotenv\n",
    "_ = load_dotenv(find_dotenv()) # read local .env file\n",
    "hf_api_key = os.environ['HF_API_KEY']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "095da8fe-24aa-4dc7-8e08-aa2f949ae21f",
   "metadata": {
    "height": 131
   },
   "outputs": [],
   "source": [
    "# Set up the text-generation client\n",
    "import requests, json\n",
    "from text_generation import Client\n",
    "\n",
    "# Falcon-instruct endpoint, accessed through the text_generation library\n",
    "client = Client(os.environ['HF_API_FALCOM_BASE'], headers={\"Authorization\": f\"Basic {hf_api_key}\"}, timeout=120)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bfe6fc97",
   "metadata": {},
   "source": [
    "## Building an app to chat with any LLM"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "745a3c9b",
   "metadata": {},
   "source": [
    "Here we'll use an [Inference Endpoint](https://huggingface.co/inference-endpoints) for `falcon-40b-instruct`, which was the top-ranked open-source LLM on the [🤗 Open LLM Leaderboard](https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard) when this lesson was written."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7065860-3c0b-490d-9e7c-22e5b79fc004",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "prompt = \"Has math been invented or discovered?\"\n",
    "client.generate(prompt, max_new_tokens=256).generated_text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0dcb659e-b71b-46da-b9d2-6ee62498995f",
   "metadata": {
    "height": 301
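   },
   "outputs": [],
   "source": [
    "#Back to Lesson 2, time flies!\n",
    "import gradio as gr\n",
    "def generate(prompt, max_new_tokens):\n",
    "    # Parameter names avoid shadowing the built-in input(); cast in case the slider returns a float\n",
    "    output = client.generate(prompt, max_new_tokens=int(max_new_tokens)).generated_text\n",
    "    return output\n",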
    "\n",
    "demo = gr.Interface(fn=generate, \n",
    "                    inputs=[gr.Textbox(label=\"Prompt\"), \n",
    "                            gr.Slider(label=\"Max new tokens\", \n",
    "                                      value=20,  \n",
    "                                      maximum=1024, \n",
    "                                      minimum=1)], \n",
    "                    outputs=[gr.Textbox(label=\"Completion\")])\n",
    "\n",
    "gr.close_all()\n",
    "demo.launch(share=True, server_port=int(os.environ['PORT1']))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8e5f55e2",
   "metadata": {},
   "source": [
    "## `gr.Chatbot()`\n",
    "\n",
    "- `gr.Chatbot()` allows you to save the chat history (between the user and the LLM) as well as display the dialogue in the app.\n",
    "- Define your `fn` to take in a `gr.Chatbot()` object.  \n",
    "  - Within your defined `fn` function, append a tuple (or a list) containing the user message and the LLM's response:\n",
    "`chatbot_object.append( (user_message, llm_message) )`\n",
    "\n",
    "- Include the chatbot object in both the inputs and the outputs of the app."
   ]
  },
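  {
   "cell_type": "markdown",
   "id": "e1a7c0d2",
   "metadata": {},
   "source": [
    "To make the data shape concrete, here is a minimal plain-Python sketch of the history such an `fn` maintains. It assumes the history behaves like a list of `(user_message, llm_message)` pairs, as described above; `echo_respond` and its canned reply are hypothetical stand-ins for a real model call, and no Gradio objects are involved.\n",
    "\n",
    "```python\n",
    "def echo_respond(message, chat_history):\n",
    "    # Hypothetical placeholder reply; a real app would call the LLM here.\n",
    "    bot_message = f'You said: {message}'\n",
    "    # One conversational turn = one (user_message, llm_message) pair.\n",
    "    chat_history.append((message, bot_message))\n",
    "    # Empty string for the textbox output, updated history for the chatbot output.\n",
    "    return '', chat_history\n",
    "\n",
    "history = []\n",
    "_, history = echo_respond('Hello', history)\n",
    "_, history = echo_respond('How are you?', history)\n",
    "\n",
    "for user_message, bot_message in history:\n",
    "    print(f'User: {user_message} | Assistant: {bot_message}')\n",
    "```"
   ]
  },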
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43beebb7-40a6-4af5-a701-882821b6ed36",
   "metadata": {
    "height": 386
   },
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "def respond(message, chat_history):\n",
    "    # No LLM here, just respond with a random pre-made message\n",
    "    bot_message = random.choice([\"Tell me more about it\",\n",
    "                                 \"Cool, but I'm not interested\",\n",
    "                                 \"Hmmmm, ok then\"])\n",
    "    chat_history.append((message, bot_message))\n",
    "    return \"\", chat_history\n",
    "\n",
    "with gr.Blocks() as demo:\n",
    "    chatbot = gr.Chatbot(height=240) #just to fit the notebook\n",
    "    msg = gr.Textbox(label=\"Prompt\")\n",
    "    btn = gr.Button(\"Submit\")\n",
    "    clear = gr.ClearButton(components=[msg, chatbot], value=\"Clear console\")\n",
    "\n",
    "    btn.click(respond, inputs=[msg, chatbot], outputs=[msg, chatbot])\n",
    "    msg.submit(respond, inputs=[msg, chatbot], outputs=[msg, chatbot]) #Press enter to submit\n",
    "\n",
    "gr.close_all()\n",
    "demo.launch(share=True, server_port=int(os.environ['PORT2']))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8646d777-c211-4d31-9426-7b5d78b533ae",
   "metadata": {},
   "source": [
    "#### Format the prompt with the chat history\n",
    "\n",
    "- You can iterate through the chatbot object with a for loop.\n",
    "- Each item is a tuple containing the user message and the LLM's message.\n",
    "\n",
    "```Python\n",
    "for turn in chat_history:\n",
    "    user_msg, bot_msg = turn\n",
    "    ...\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55bae99d-7a63-4a40-bab7-de7d10b8ab1b",
   "metadata": {
    "height": 488
   },
   "outputs": [],
   "source": [
    "def format_chat_prompt(message, chat_history):\n",
    "    prompt = \"\"\n",
    "    for turn in chat_history:\n",
    "        user_message, bot_message = turn\n",
    "        prompt = f\"{prompt}\\nUser: {user_message}\\nAssistant: {bot_message}\"\n",
    "    prompt = f\"{prompt}\\nUser: {message}\\nAssistant:\"\n",
    "    return prompt\n",
    "\n",
    "def respond(message, chat_history):\n",
    "    formatted_prompt = format_chat_prompt(message, chat_history)\n",
    "    # stop_sequences keeps the model from going on to write the user's next message\n",
    "    bot_message = client.generate(formatted_prompt,\n",
    "                                  max_new_tokens=1024,\n",
    "                                  stop_sequences=[\"\\nUser:\", \"<|endoftext|>\"]).generated_text\n",
    "    chat_history.append((message, bot_message))\n",
    "    return \"\", chat_history\n",
    "\n",
    "with gr.Blocks() as demo:\n",
    "    chatbot = gr.Chatbot(height=240) #just to fit the notebook\n",
    "    msg = gr.Textbox(label=\"Prompt\")\n",
    "    btn = gr.Button(\"Submit\")\n",
    "    clear = gr.ClearButton(components=[msg, chatbot], value=\"Clear console\")\n",
    "\n",
    "    btn.click(respond, inputs=[msg, chatbot], outputs=[msg, chatbot])\n",
    "    msg.submit(respond, inputs=[msg, chatbot], outputs=[msg, chatbot]) #Press enter to submit\n",
    "\n",
    "gr.close_all()\n",
    "demo.launch(share=True, server_port=int(os.environ['PORT3']))"
   ]
  },
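  {
   "cell_type": "markdown",
   "id": "a3c9e5f1",
   "metadata": {},
   "source": [
    "To inspect what the model actually receives, the sketch below reruns the same formatting logic on a small history and prints the result. The sample messages are made up for illustration, and the function is repeated here only so the snippet runs on its own.\n",
    "\n",
    "```python\n",
    "def format_chat_prompt(message, chat_history):\n",
    "    # Same logic as the cell above: replay earlier turns, then open a new assistant turn.\n",
    "    prompt = ''\n",
    "    for user_message, bot_message in chat_history:\n",
    "        prompt = f'{prompt}\\nUser: {user_message}\\nAssistant: {bot_message}'\n",
    "    return f'{prompt}\\nUser: {message}\\nAssistant:'\n",
    "\n",
    "sample_history = [('Hi', 'Hello! How can I help?')]  # made-up example turn\n",
    "prompt = format_chat_prompt('What is 2 + 2?', sample_history)\n",
    "print(prompt)\n",
    "```\n",
    "\n",
    "The prompt ends with `Assistant:`, so the model continues from there, and the `\\nUser:` stop sequence is meant to end generation before the model starts writing the next user turn."
   ]
  },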
  {
   "cell_type": "markdown",
   "id": "f22b8de8",
   "metadata": {},
   "source": [
    "### Adding other advanced features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e4fff81-a3d1-4cb8-8d6e-d152ab39065a",
   "metadata": {
    "height": 148
   },
   "outputs": [],
   "source": [
    "def format_chat_prompt(message, chat_history, instruction):\n",
    "    prompt = f\"System:{instruction}\"\n",
    "    for turn in chat_history:\n",
    "        user_message, bot_message = turn\n",
    "        prompt = f\"{prompt}\\nUser: {user_message}\\nAssistant: {bot_message}\"\n",
    "    prompt = f\"{prompt}\\nUser: {message}\\nAssistant:\"\n",
    "    return prompt"
   ]
  },
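  {
   "cell_type": "markdown",
   "id": "b7d2f4a8",
   "metadata": {},
   "source": [
    "A quick check of the new signature, using made-up messages: the instruction lands on a `System:` line ahead of the replayed turns. The function body mirrors the cell above so the snippet runs on its own.\n",
    "\n",
    "```python\n",
    "def format_chat_prompt(message, chat_history, instruction):\n",
    "    prompt = f'System:{instruction}'\n",
    "    for user_message, bot_message in chat_history:\n",
    "        prompt = f'{prompt}\\nUser: {user_message}\\nAssistant: {bot_message}'\n",
    "    return f'{prompt}\\nUser: {message}\\nAssistant:'\n",
    "\n",
    "prompt = format_chat_prompt(\n",
    "    'And in French?',\n",
    "    [('Say hello in Spanish.', 'Hola!')],       # made-up earlier turn\n",
    "    'You are a concise translation assistant.',  # made-up system instruction\n",
    ")\n",
    "lines = prompt.split('\\n')\n",
    "print(lines)\n",
    "```"
   ]
  },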
  {
   "cell_type": "markdown",
   "id": "d3ee9bc5-fce7-44b1-af2a-e69bc7c598b6",
   "metadata": {},
   "source": [
    "### Streaming\n",
    "\n",
    "- If your LLM can provide its tokens one at a time in a stream, you can accumulate those tokens in the chatbot object.\n",
    "- The `for` loop in the following function goes through all the tokens that are in the stream and appends them to the most recent conversational turn in the chatbot's message history."
   ]
  },
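  {
   "cell_type": "markdown",
   "id": "c4e8a1b6",
   "metadata": {},
   "source": [
    "The accumulation pattern can be tried without a model. In this sketch, `fake_token_stream` is a hypothetical stand-in for a streaming client (it simply yields word-sized pieces of a fixed reply), and `stream_into_history` is an illustrative name, not part of Gradio or `text_generation`.\n",
    "\n",
    "```python\n",
    "def fake_token_stream(text):\n",
    "    # Hypothetical stand-in for a streaming LLM client: yields word-sized tokens.\n",
    "    for i, word in enumerate(text.split(' ')):\n",
    "        yield word if i == 0 else ' ' + word\n",
    "\n",
    "def stream_into_history(message, chat_history, reply_text):\n",
    "    # Open a new turn with an empty assistant reply (the caller's list is not mutated).\n",
    "    chat_history = chat_history + [[message, '']]\n",
    "    for token in fake_token_stream(reply_text):\n",
    "        # Grow the assistant side of the most recent turn, then emit the update.\n",
    "        chat_history[-1][1] += token\n",
    "        yield '', chat_history\n",
    "\n",
    "# Record the assistant text after each streamed update.\n",
    "partials = [history[-1][1] for _, history in stream_into_history('Hi', [], 'Hello there, friend')]\n",
    "print(partials)\n",
    "```\n",
    "\n",
    "Because the function yields after every token, a UI that consumes it can redraw the reply as it grows."
   ]
  },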
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "700eb3bc-b63a-4ccb-94c4-70ec2e54bcda",
   "metadata": {
    "height": 454
   },
   "outputs": [],
   "source": [
    "def respond(message, chat_history, instruction, temperature=0.7):\n",
    "    prompt = format_chat_prompt(message, chat_history, instruction)\n",
    "    # Add the new turn with an empty assistant reply, to be filled in as tokens arrive\n",
    "    chat_history = chat_history + [[message, \"\"]]\n",
    "    # stop_sequences keeps the model from generating the user's next message\n",
    "    stream = client.generate_stream(prompt,\n",
    "                                    max_new_tokens=1024,\n",
    "                                    stop_sequences=[\"\\nUser:\", \"<|endoftext|>\"],\n",
    "                                    temperature=temperature)\n",
    "    # Stream the tokens\n",
    "    for idx, response in enumerate(stream):\n",
    "        text_token = response.token.text\n",
    "\n",
    "        # `details` is only set on the final streamed response, so stop there\n",
    "        if response.details:\n",
    "            return\n",
    "\n",
    "        # Drop the leading space of the first token\n",
    "        if idx == 0 and text_token.startswith(\" \"):\n",
    "            text_token = text_token[1:]\n",
    "\n",
    "        # Append the token to the assistant reply of the most recent turn\n",
    "        last_turn = list(chat_history.pop(-1))\n",
    "        last_turn[-1] += text_token\n",
    "        chat_history = chat_history + [last_turn]\n",
    "        yield \"\", chat_history"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09873dfd-5b6c-41d6-9479-12e8c8894295",
   "metadata": {
    "height": 267
   },
   "outputs": [],
   "source": [
    "with gr.Blocks() as demo:\n",
    "    chatbot = gr.Chatbot(height=240) #just to fit the notebook\n",
    "    msg = gr.Textbox(label=\"Prompt\")\n",
    "    with gr.Accordion(label=\"Advanced options\",open=False):\n",
    "        system = gr.Textbox(label=\"System message\", lines=2, value=\"A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.\")\n",
    "        temperature = gr.Slider(label=\"temperature\", minimum=0.1, maximum=1, value=0.7, step=0.1)\n",
    "    btn = gr.Button(\"Submit\")\n",
    "    clear = gr.ClearButton(components=[msg, chatbot], value=\"Clear console\")\n",
    "\n",
    "    # Pass the temperature slider too, so changing it actually affects generation\n",
    "    btn.click(respond, inputs=[msg, chatbot, system, temperature], outputs=[msg, chatbot])\n",
    "    msg.submit(respond, inputs=[msg, chatbot, system, temperature], outputs=[msg, chatbot]) #Press enter to submit\n",
    "\n",
    "gr.close_all()\n",
    "demo.queue().launch(share=True, server_port=int(os.environ['PORT4']))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4a51a07",
   "metadata": {},
   "source": [
    "Notice that the cell above uses `demo.queue().launch()` instead of `demo.launch()`. Enabling the queue lets Gradio line up incoming requests and stream the partial outputs that `respond` yields, which helps keep the demo responsive when several users connect at once. See [setting up a demo for maximum performance](https://www.gradio.app/guides/setting-up-a-demo-for-maximum-performance) for more details."
   ]
  },
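  {
   "cell_type": "markdown",
   "id": "d9b3f6c2",
   "metadata": {},
   "source": [
    "As a small illustration of why the queue matters for streaming: a handler written as a generator emits several partial outputs per request. In Gradio 3.x such handlers needed the queue enabled, while later releases enable it by default, so check the behavior of your installed version. `slow_count` and `build_demo` are hypothetical names for this sketch; Gradio is imported inside `build_demo` so the generator can be tried on its own.\n",
    "\n",
    "```python\n",
    "import time\n",
    "\n",
    "def slow_count(n):\n",
    "    # Generator handler: yields a growing string, one step at a time.\n",
    "    text = ''\n",
    "    for i in range(1, int(n) + 1):\n",
    "        time.sleep(0.05)  # simulate slow, incremental work\n",
    "        text += f'{i} '\n",
    "        yield text.strip()\n",
    "\n",
    "def build_demo():\n",
    "    import gradio as gr  # imported lazily; only needed to build the UI\n",
    "    with gr.Blocks() as demo:\n",
    "        n = gr.Slider(label='Count to', minimum=1, maximum=10, value=5, step=1)\n",
    "        out = gr.Textbox(label='Progress')\n",
    "        gr.Button('Start').click(slow_count, inputs=[n], outputs=[out])\n",
    "    return demo\n",
    "\n",
    "# In a notebook cell: build_demo().queue().launch()\n",
    "print(list(slow_count(3)))\n",
    "```"
   ]
  },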
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d9ec80a-39ad-4f58-b79e-4f413c5074c0",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "gr.close_all()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12cf9b3a-4202-4e3a-9c6b-941fa1290ab8",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
